The learnable weights and quantization parameters in the n-th module are updated by
minimizing the reconstruction errors. The proposed MREM can be optimized in parallel:
given the previously trained modules, only the weights and quantization parameters in the
current module are updated. Moreover, the number of modules $N$ can be adjusted depending
on the memory constraints of the available computing resources. This flexibility in the number
of transformer layers per module ensures that a proper trade-off between layer-wise correlation
and the memory overhead of the training devices can be achieved. Although a similar
block-wise objective was previously proposed in [137], it requires computing second-order
Hessian matrices for optimization, which can be computationally prohibitive for large
language models.
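To make the module-wise objective concrete, the following PyTorch-style sketch optimizes the $n$-th quantized module against its full-precision counterpart. It is a minimal illustration under stated assumptions rather than the authors' implementation: the mean-squared reconstruction error stands in for Eq. (5.10), and the names `train_module_n`, `fp_modules`, `q_modules`, and `calib_inputs` are hypothetical.

```python
import torch
import torch.nn.functional as F

def train_module_n(fp_modules, q_modules, n, calib_inputs, steps, lr=1e-4):
    """Minimize the reconstruction error of the n-th quantized module.

    fp_modules / q_modules are lists of full-precision / quantized modules,
    each grouping several consecutive transformer layers. Only the weights
    and quantization parameters of q_modules[n] are updated; the preceding
    modules are assumed to be already trained and are kept frozen.
    """
    optimizer = torch.optim.Adam(q_modules[n].parameters(), lr=lr)
    for step in range(steps):
        x = calib_inputs[step % len(calib_inputs)]
        with torch.no_grad():
            fp_h, q_h = x, x
            for m in range(n):                    # frozen predecessor modules
                fp_h = fp_modules[m](fp_h)
                q_h = q_modules[m](q_h)
            target = fp_modules[n](fp_h)          # full-precision output f_n
        pred = q_modules[n](q_h)                  # quantized output \hat{f}_n
        loss = F.mse_loss(pred, target)           # reconstruction error of module n
        optimizer.zero_grad()
        loss.backward()                           # gradients stay local to module n
        optimizer.step()
```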
5.5.2 Model Parallel Strategy
Second, a new model parallel strategy is designed to accelerate the training of MREM.
A straightforward strategy is to optimize the modules one by one; however, such sequential
training is still time-consuming. Motivated by this, the authors propose a model parallel
strategy that allows all modules to be trained jointly, without synchronizing with adjacent
partitioned modules, by allocating each partitioned module to an individual computing device.
Specifically, every module is computed one after another in the first $t_0$ steps to construct
an input queue $\mathcal{I}$, which contains $t_0$ intermediate output results. For the $n$-th module,
its input queue comes from the previous module, i.e., $\mathcal{I}^{t_0}_{n-1} = \{f^{1}_{n-1}, f^{2}_{n-1}, f^{3}_{n-1}, \ldots, f^{t_0}_{n-1}\}$.
Then, parallel training takes place. Each module samples its input from the corresponding
input queue and optimizes the loss defined by Eq. (5.10). Meanwhile, the input queue is also
updated throughout training following the first-in-first-out rule: once a module produces
its output, the result is fed into the input queue of the following module. In the backward
pass, the gradients propagate locally within each module, without affecting its predecessors.
As a result, such a design avoids the load imbalance issue caused by straggler modules,
bringing nearly the theoretical $N\times$ speed-up when deployed on $N$ GPUs. These results are
superior to previous data parallel [131] or model parallel [96] techniques.
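The queue mechanics can be sketched as follows. This is a simplified single-process illustration under stated assumptions: `parallel_mrem_training`, `q_modules`, `calib_inputs`, and `train_step` are hypothetical names, and in practice each module (and the corresponding iteration of the inner loop) would run concurrently on its own device.

```python
from collections import deque
import random
import torch

def parallel_mrem_training(q_modules, calib_inputs, t0, total_steps, train_step):
    """Single-process sketch of the queue-based model parallel training.

    q_modules[n] is assumed to live on its own GPU in a real deployment;
    here only the FIFO input-queue mechanics are shown. train_step(module, x)
    is assumed to run one optimization step of the reconstruction loss in
    Eq. (5.10) and to return the module's output for x.
    """
    N = len(q_modules)
    queues = [deque(maxlen=t0) for _ in range(N)]   # queues[n] feeds module n+1

    # Warm-up: run the modules one after another for the first t0 steps
    # so that every input queue holds t0 intermediate outputs.
    for t in range(t0):
        h = calib_inputs[t]
        for n, module in enumerate(q_modules):
            with torch.no_grad():
                h = module(h)
            queues[n].append(h.detach())

    # Parallel phase: each module samples from its predecessor's queue,
    # runs a local forward/backward pass, and refreshes its own queue
    # following the first-in-first-out rule, with no synchronization
    # between adjacent modules.
    for step in range(total_steps):
        for n, module in enumerate(q_modules):
            if n == 0:
                x = calib_inputs[(t0 + step) % len(calib_inputs)]
            else:
                x = random.choice(list(queues[n - 1]))
            out = train_step(module, x)             # gradients stay inside module n
            queues[n].append(out.detach())          # maxlen=t0 evicts the oldest entry
```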
5.5.3 Annealed Teacher Forcing
Third, the authors design an annealed teacher forcing scheme for the parallel strategy. They find
that the naive parallel training suffers from the propagation of reconstruction error since
each quantized module passes the quantization error to its successors before being fully
optimized. In particular, all modules get optimized simultaneously instead of sequentially
in the parallel strategy. The next module takes the output from the input queue before
its predecessor is fully optimized. Therefore, the predecessor’s reconstruction error will
propagate to the following modules before it is sufficiently minimized. To solve this problem,
the proposed annealed teacher forcing, similar to the method in [246], lets the full-precision
module provide clean signals to the next quantized module. This breaks the propagation of
reconstruction error and further improves the performance of the parallel strategy. Specifically,
the output $f_n$ from the $n$-th full-precision module serves as the clean input to the $(n+1)$-th
quantized module, substituting the original $\hat{f}_n$ that comes from the quantized module. As
a result, $f_n$ stops the propagation of the accumulated error along the quantized modules.
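A minimal sketch of this substitution is given below, in the spirit of scheduled sampling [246]. The stochastic substitution, the linear annealing schedule, and the function name `teacher_forcing_input` are illustrative assumptions; the precise input rule used by the authors is described next.

```python
import random

def teacher_forcing_input(f_clean, f_quant, step, total_steps):
    """Annealed substitution of the clean signal for the quantized one.

    f_clean: output f_n of the n-th full-precision module (clean signal).
    f_quant: output of the n-th quantized module, which may carry the
    accumulated reconstruction error of its predecessors.
    """
    p_clean = max(0.0, 1.0 - step / total_steps)    # annealed from 1 to 0
    # Early in training the clean f_n is almost always used, cutting off
    # error propagation; later the quantized output dominates, keeping the
    # forward pass closer to what happens at inference time.
    return f_clean if random.random() < p_clean else f_quant
```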
Nevertheless, such an approach breaks the connection to the previous quantized modules and
may suffer from forward inconsistency between training and inference for the quantized
model. To solve this problem, the actual input to the $(n+1)$-th quantized module is the